#!/usr/bin/env python  
#-*- coding:utf-8 _*-  
""" 
@author:hello_life 
@license: Apache Licence 
@file: albert_model.py 
@time: 2022/04/21
@software: PyCharm 
description: ALBERT encoder with a single linear classification head applied to its pooled output.
"""
import sys,os
sys.path.insert(0,os.path.dirname(os.getcwd()))

import torch
import torch.nn as nn
from transformers import AlbertModel
from parameters.albert_config import Config

class Albert_model(nn.Module):
    """ALBERT encoder followed by a linear classification head.

    The pretrained encoder is loaded from ``config.model_path`` and a single
    ``nn.Linear`` layer maps its pooled output to ``config.num_class`` logits.
    """

    def __init__(self, config):
        """Build the encoder and the classification head.

        :param config: object exposing ``model_path`` (passed to
            ``AlbertModel.from_pretrained``) and ``num_class`` (number of
            output logits).
        """
        super(Albert_model, self).__init__()
        self.albert = AlbertModel.from_pretrained(config.model_path)
        # Size the head from the loaded checkpoint instead of hard-coding 768:
        # only ALBERT-base has hidden_size=768 (large/xlarge/xxlarge use
        # 1024/2048/4096), so a fixed 768 breaks with any other checkpoint.
        # Attribute names are unchanged so saved state_dicts still load.
        self.l1 = nn.Linear(self.albert.config.hidden_size, config.num_class)

    def forward(self, x):
        """Run the encoder and project the pooled output to class logits.

        :param x: mapping of keyword arguments unpacked into the ALBERT
            forward call; presumably tokenizer output such as ``input_ids`` /
            ``attention_mask`` / ``token_type_ids`` -- TODO confirm with caller.
        :return: logits produced by ``self.l1``; per the transformers docs
            ``pooler_output`` is ``(batch, hidden_size)``, so this should be
            ``(batch, num_class)``.
        """
        outputs = self.albert(**x)
        return self.l1(outputs.pooler_output)


if __name__ == '__main__':
    # Quick smoke check: build the model from the default Config and show the
    # name of the first registered parameter (nothing is printed if the model
    # has no parameters).
    config = Config()
    model = Albert_model(config)
    first_param = next(iter(model.named_parameters()), None)
    if first_param is not None:
        param_name, _ = first_param
        print(param_name)